md"""
# Tracking MLJ experiment using MLFlowClient.jl
This notebook tunes the `max_depth` hyperparameter of a decision tree classifier. MLflow records the full set of `max_depth` values as a single array parameter and receives the accuracy of each evaluated model.
"""
Activating project at `~/git/ds_portfolio/notebooks/mlflow_with_mlj`
using MLJ, DataFrames
md"### Data ingestion"
| | sepal_length | sepal_width | petal_length | petal_width | target |
|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 | CategoricalValue |
| 1 | 5.1 | 3.5 | 1.4 | 0.2 | "setosa" |
| 2 | 4.9 | 3.0 | 1.4 | 0.2 | "setosa" |
| 3 | 4.7 | 3.2 | 1.3 | 0.2 | "setosa" |
| 4 | 4.6 | 3.1 | 1.5 | 0.2 | "setosa" |
| 5 | 5.0 | 3.6 | 1.4 | 0.2 | "setosa" |
| 6 | 5.4 | 3.9 | 1.7 | 0.4 | "setosa" |
| 7 | 4.6 | 3.4 | 1.4 | 0.3 | "setosa" |
| 8 | 5.0 | 3.4 | 1.5 | 0.2 | "setosa" |
| 9 | 4.4 | 2.9 | 1.4 | 0.2 | "setosa" |
| 10 | 4.9 | 3.1 | 1.5 | 0.1 | "setosa" |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 150 | 5.9 | 3.0 | 5.1 | 1.8 | "virginica" |
┌──────────────┬───────────────┬──────────────────────────────────┐
│ names        │ scitypes      │ types                            │
├──────────────┼───────────────┼──────────────────────────────────┤
│ sepal_length │ Continuous    │ Float64                          │
│ sepal_width  │ Continuous    │ Float64                          │
│ petal_length │ Continuous    │ Float64                          │
│ petal_width  │ Continuous    │ Float64                          │
│ target       │ Multiclass{3} │ CategoricalValue{String, UInt32} │
└──────────────┴───────────────┴──────────────────────────────────┘
| | sepal_length | sepal_width | petal_length | petal_width | target |
|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 | CategoricalValue |
| 1 | 6.4 | 3.2 | 4.5 | 1.5 | "versicolor" |
| 2 | 6.7 | 3.1 | 4.4 | 1.4 | "versicolor" |
| 3 | 4.9 | 3.0 | 1.4 | 0.2 | "setosa" |
| 4 | 6.3 | 2.5 | 4.9 | 1.5 | "versicolor" |
| 5 | 6.2 | 2.2 | 4.5 | 1.5 | "versicolor" |
| 6 | 5.7 | 3.8 | 1.7 | 0.3 | "setosa" |
| 7 | 6.9 | 3.2 | 5.7 | 2.3 | "virginica" |
| 8 | 4.4 | 3.0 | 1.3 | 0.2 | "setosa" |
| 9 | 6.1 | 2.8 | 4.0 | 1.3 | "versicolor" |
| 10 | 5.4 | 3.4 | 1.7 | 0.2 | "setosa" |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 120 | 5.7 | 2.6 | 3.5 | 1.0 | "versicolor" |
| | sepal_length | sepal_width | petal_length | petal_width | target |
|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 | CategoricalValue |
| 1 | 5.4 | 3.9 | 1.3 | 0.4 | "setosa" |
| 2 | 6.4 | 3.1 | 5.5 | 1.8 | "virginica" |
| 3 | 5.8 | 2.8 | 5.1 | 2.4 | "virginica" |
| 4 | 6.3 | 2.3 | 4.4 | 1.3 | "versicolor" |
| 5 | 5.2 | 2.7 | 3.9 | 1.4 | "versicolor" |
| 6 | 6.5 | 2.8 | 4.6 | 1.5 | "versicolor" |
| 7 | 6.8 | 2.8 | 4.8 | 1.4 | "versicolor" |
| 8 | 4.9 | 3.1 | 1.5 | 0.1 | "setosa" |
| 9 | 5.6 | 2.8 | 4.9 | 2.0 | "virginica" |
| 10 | 6.4 | 2.9 | 4.3 | 1.3 | "versicolor" |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 30 | 5.6 | 2.5 | 3.9 | 1.1 | "versicolor" |
"versicolor"
"versicolor"
"setosa"
"versicolor"
"versicolor"
"setosa"
"virginica"
"setosa"
"versicolor"
"versicolor"
| | sepal_length | sepal_width | petal_length | petal_width |
|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 |
| 1 | 6.4 | 3.2 | 4.5 | 1.5 |
| 2 | 6.7 | 3.1 | 4.4 | 1.4 |
| 3 | 4.9 | 3.0 | 1.4 | 0.2 |
| 4 | 6.3 | 2.5 | 4.9 | 1.5 |
| 5 | 6.2 | 2.2 | 4.5 | 1.5 |
| 6 | 5.7 | 3.8 | 1.7 | 0.3 |
| 7 | 6.9 | 3.2 | 5.7 | 2.3 |
| 8 | 4.4 | 3.0 | 1.3 | 0.2 |
| 9 | 6.1 | 2.8 | 4.0 | 1.3 |
| 10 | 5.4 | 3.4 | 1.7 | 0.2 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 120 | 5.7 | 2.6 | 3.5 | 1.0 |
"setosa"
"virginica"
"virginica"
"versicolor"
"versicolor"
"versicolor"
"versicolor"
"setosa"
"virginica"
"versicolor"
| | sepal_length | sepal_width | petal_length | petal_width |
|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 |
| 1 | 5.4 | 3.9 | 1.3 | 0.4 |
| 2 | 6.4 | 3.1 | 5.5 | 1.8 |
| 3 | 5.8 | 2.8 | 5.1 | 2.4 |
| 4 | 6.3 | 2.3 | 4.4 | 1.3 |
| 5 | 5.2 | 2.7 | 3.9 | 1.4 |
| 6 | 6.5 | 2.8 | 4.6 | 1.5 |
| 7 | 6.8 | 2.8 | 4.8 | 1.4 |
| 8 | 4.9 | 3.1 | 1.5 | 0.1 |
| 9 | 5.6 | 2.8 | 4.9 | 2.0 |
| 10 | 6.4 | 2.9 | 4.3 | 1.3 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 30 | 5.6 | 2.5 | 3.9 | 1.1 |
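The cells that produced the outputs above are not shown. A minimal sketch that would yield the same kinds of output (the full table, its schema, a 120/30 row split, and separate target vectors and feature tables) might look like the following. It assumes the data comes from MLJ's built-in `load_iris` and that an 80/20 shuffled split was used; the variable names and the seed are hypothetical.

```julia
using MLJ, DataFrames

# Load the built-in iris data as a DataFrame (150 rows, 5 columns).
iris = DataFrame(load_iris())

# Inspect column names, scientific types and machine types.
schema(iris)

# Assumption: an 80/20 shuffled split, which matches the 120 and 30 rows shown.
# The seed is hypothetical; the original seed (if any) is unknown.
train, test = partition(iris, 0.8, shuffle=true, rng=42)

# Separate the target vector from the feature table for each split.
y_train, X_train = unpack(train, ==(:target))
y_test, X_test = unpack(test, ==(:target))
```

Because the seed here is made up, the rows that land in each split will differ from the ones displayed above.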
md"""
### MLFlowClient setup
"""
using MLFlowClient
MLFlowClient.MLFlow(
    baseuri = "http://localhost:5000",
    apiversion = 2.0
)
mlf = MLFlow("http://localhost:5000")
MLFlowClient.MLFlowExperiment(
    name = "iris_classification",
    lifecycle_stage = "active",
    experiment_id = 645006071875603648,
    tags = missing,
    artifact_location = "/home/pebeto/git/ds_portfolio/notebooks/mlflow_with_mlj/iris-artifacts"
)
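The call that returned the experiment above is not shown. One way to obtain such an object is sketched here, assuming the MLFlowClient.jl 0.4-series function `getorcreateexperiment` and a tracking server already listening on `http://localhost:5000`. The helper name `setup_experiment` and the relative artifact path are hypothetical.

```julia
using MLFlowClient

# Hypothetical helper: connect to a tracking server and fetch an experiment
# by name, creating it on first use.
# Assumption: the `artifact_location` keyword of `getorcreateexperiment`
# (MLFlowClient.jl 0.4 series).
function setup_experiment(baseuri, name; artifact_location="iris-artifacts")
    mlf = MLFlow(baseuri)  # builds the client object only
    experiment = getorcreateexperiment(mlf, name;
        artifact_location=artifact_location)
    return mlf, experiment
end
```

With a server running, `mlf, experiment = setup_experiment("http://localhost:5000", "iris_classification")` would return the client and the experiment; `experiment.experiment_id` is the identifier that run-creation calls expect.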
md"""
### Modeling
"""
MLJDecisionTreeInterface.DecisionTreeClassifier
For silent loading, specify `verbosity=0`.
import MLJDecisionTreeInterface ✔
DecisionTreeClassifier(
  max_depth = -1,
  min_samples_leaf = 1,
  min_samples_split = 2,
  min_purity_increase = 0.0,
  n_subfeatures = 0,
  post_prune = false,
  merge_purity_threshold = 1.0,
  display_depth = 5,
  feature_importance = :impurity,
  rng = Random._GLOBAL_RNG())
ProbabilisticTunedModel(
  model = DecisionTreeClassifier(
        max_depth = -1,
        min_samples_leaf = 1,
        min_samples_split = 2,
        min_purity_increase = 0.0,
        n_subfeatures = 0,
        post_prune = false,
        merge_purity_threshold = 1.0,
        display_depth = 5,
        feature_importance = :impurity,
        rng = Random._GLOBAL_RNG()),
  tuning = Grid(
        goal = nothing,
        resolution = 10,
        shuffle = true,
        rng = Random._GLOBAL_RNG()),
  resampling = CV(
        nfolds = 6,
        shuffle = false,
        rng = Random._GLOBAL_RNG()),
  measure = MLJBase.Measure[Accuracy(), LogLoss(tol = 2.220446049250313e-16), MisclassificationRate(), BrierScore()],
  weights = nothing,
  class_weights = nothing,
  operation = nothing,
  range = NumericRange(2 ≤ max_depth ≤ 10; origin=6.0, unit=4.0),
  selection_heuristic = MLJTuning.NaiveSelection(nothing),
  train_best = true,
  repeats = 1,
  n = nothing,
  acceleration = CPU1{Nothing}(nothing),
  acceleration_resampling = CPU1{Nothing}(nothing),
  check_measure = true,
  cache = true)
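model = TunedModel(
    # Assumption: `dtc` names the DecisionTreeClassifier instance displayed above;
    # the cell that defines it is not shown.
    model=dtc,
    range=range(dtc, :max_depth, lower=2, upper=10),
    resampling=CV(),
    tuning=Grid(),
    measure=[accuracy, log_loss, misclassification_rate, brier_score]
)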
untrained Machine; does not cache data
  model: ProbabilisticTunedModel(model = DecisionTreeClassifier(max_depth = -1, …), …)
  args:
    1: Source @344 ⏎ ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}
    2: Source @690 ⏎ AbstractVector{ScientificTypesBase.Multiclass{3}}
trained Machine; does not cache data
  model: ProbabilisticTunedModel(model = DecisionTreeClassifier(max_depth = -1, …), …)
  args:
    1: Source @344 ⏎ ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}
    2: Source @690 ⏎ AbstractVector{ScientificTypesBase.Multiclass{3}}
Training machine(ProbabilisticTunedModel(model = DecisionTreeClassifier(max_depth = -1, …), …), …).
Attempting to evaluate 9 models.
Evaluating over 9 metamodels: 100%[=========================] Time: 0:00:06
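The machine displays and training log above are consistent with wrapping the tuned model in a machine and fitting it. A self-contained sketch of that workflow follows, assuming MLJDecisionTreeInterface is installed. The name `dtc` is hypothetical, and the full iris data is used here only to keep the example short, whereas the notebook fits on its training split.

```julia
using MLJ

# Load the decision tree model type and instantiate it with default settings.
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
dtc = DecisionTreeClassifier()

# Tune `max_depth` over 2 to 10 with a grid search and cross-validation.
model = TunedModel(
    model=dtc,
    range=range(dtc, :max_depth, lower=2, upper=10),
    resampling=CV(),
    tuning=Grid(),
    measure=[accuracy, log_loss, misclassification_rate, brier_score],
)

# Assumption: features and target come straight from the built-in data set.
X, y = @load_iris

# Bind the tuned model to the data and train it.
mach = machine(model, X, y)
fit!(mach, verbosity=0)

# One history entry is recorded per evaluated model.
n_evaluated = length(report(mach).history)
```

With an integer range from 2 to 10 and the default grid resolution of 10, the grid contains nine distinct depths, which matches the "Attempting to evaluate 9 models" line in the log.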
md"### Evaluating"
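Each result in this section pairs the four measures and their measurements with the `max_depth` of the model that produced them. A sketch of how such pairs could be pulled out of a tuning history is given below, assuming they are read from `report(mach).history`, whose entries expose `model` and `measurement` fields. The helper `summarize_history` and the machine name `mach` are hypothetical, and the helper is written against plain named tuples so it runs without MLJ.

```julia
# Hypothetical helper: flatten tuning-history entries into (max_depth, accuracy) rows.
# Assumes each entry has `model.max_depth` and a `measurement` vector whose first
# element is the accuracy (the first measure passed to TunedModel).
function summarize_history(history)
    return [(max_depth=h.model.max_depth, accuracy=h.measurement[1]) for h in history]
end

# Stand-in entries mimicking the shape of `report(mach).history`,
# filled with two of the results reported in this section.
mock_history = [
    (model=(max_depth=5,), measurement=[0.908333, 3.02242, 0.0916667, -0.179261]),
    (model=(max_depth=8,), measurement=[0.933333, 2.40291, 0.0666667, -0.133333]),
]

rows = summarize_history(mock_history)
```

On a fitted machine the same helper would be called as `summarize_history(report(mach).history)`.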
| max_depth | Accuracy() | LogLoss(tol = 2.22045e-16) | MisclassificationRate() | BrierScore() |
|---|---|---|---|---|
| 5 | 0.908333 | 3.02242 | 0.0916667 | -0.179261 |
| 8 | 0.933333 | 2.40291 | 0.0666667 | -0.133333 |
| 6 | 0.916667 | 3.00364 | 0.0833333 | -0.166667 |
| 7 | 0.925 | 2.70327 | 0.075 | -0.15 |
| 3 | 0.925 | 0.722338 | 0.075 | -0.110157 |
| 10 | 0.916667 | 3.00364 | 0.0833333 | -0.166667 |
| 9 | 0.908333 | 3.304 | 0.0916667 | -0.183333 |
| 2 | 0.933333 | 0.755373 | 0.0666667 | -0.12145 |
| 4 | 0.925 | 2.43282 | 0.075 | -0.152475 |
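The tracking step described in the introduction (the whole set of `max_depth` values logged as one array parameter, plus the accuracy of each model) could be sketched as follows. This is not the notebook's own code. It assumes the MLFlowClient.jl 0.4-series functions `createrun`, `logparam`, `logmetric` and `updaterun`; the helper name `track_tuning` and its arguments are hypothetical.

```julia
using MLFlowClient

# Hypothetical helper: record one tuning session as a single MLflow run.
# `depths` and `accuracies` are parallel vectors, one element per evaluated model.
function track_tuning(mlf, experiment_id, depths, accuracies)
    exprun = createrun(mlf, experiment_id)
    # The whole set of depths is stored as one parameter; MLflow parameters are
    # strings, so the array is converted explicitly.
    logparam(mlf, exprun, "max_depth", string(depths))
    # One metric value per model, all logged under the same key.
    for acc in accuracies
        logmetric(mlf, exprun, "accuracy", acc)
    end
    # Mark the run as finished.
    return updaterun(mlf, exprun, "FINISHED")
end
```

With a running tracking server, a call such as `track_tuning(mlf, experiment.experiment_id, [5, 8], [0.908333, 0.933333])` would create one run holding the parameter string `"[5, 8]"` and two accuracy values. Here `mlf` and `experiment` stand for the client and experiment objects from the setup section.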